Instance Segmentation¶

Ci sono diversi modelli allo stato dell'arte che permettono di fare instance segmentation, noi ci soffermeremo su un modello particolarmente effettivo e che viene usato molto frequentemente: le Mask R-CNN progettate da Facebook AI nel 2017.
Tutorial by Francesco Pelosin @ Ca' Foscari University
Mask R-CNN in a nutshell¶
Il task di segmentazione semantica richiede la rilevazione degli oggetti nella scena e la loro segmentazione, di fatto le Mask R-CNN risolvono il problema con un approccio bottom up.
L'architettura infatti è composta da:
- Una rete backbone che funge da object detector (Faster R-CNN)
- Una rete on top che adopera image segmentation (Fully Convolutional Network)

Creazione Modello¶
Grazie al modulo torchvision di PyTorch possiamo usufruire di diversi modelli pretrainati (ossia allenati su un particolare dataset).
import torchvision # torchvision è usato per accedere ai modelli
from torchvision import transforms as T # da torchvision prendiamo transforms per adoperare trasformazioni sulle immagini
import cv2 # cv2 (OpenCV) è una libreria trasversale di algoritmi di computer vision
import matplotlib.pyplot as plt # libreria per disegnare "plottare" grafici
import numpy as np # numpy è una libreria pensata per il calcolo matriciale scientifico
# Download del modello, scarichiamo già i pesi pretrainati sul COCO dataset.
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
# Settiamo il modello in "eval" mode in quanto il comportamento durante il training differisce da quello di evaluation
model.eval()
Utilities di visualizzazione ed accesso¶
# Di seguito riportiamo le etichette degli oggetti riconoscibili dalla rete pretrainata sul COCO dataset
COCO_INSTANCE_CATEGORY_NAMES = [
'__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]
def get_prediction(img_path, threshold):
"""
Questa funzione ritorna le maschere, le bounding boxes e le classi predette
per una certa immagine ed una confidenza (threshold) da passare al modello.
"""
# Lettura immagine
img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# Trasformazione in tensore
transform = T.Compose([T.ToTensor()])
img = transform(img)
# Feed dell'immagine alla rete
pred = model([img])
# Confidenza della rete, tieni solo le predizioni sopra una certa soglia (0-1)
pred_score = list(pred[0]['scores'].detach().numpy())
pred_t = [pred_score.index(x) for x in pred_score if x>threshold][-1]
# Mantieni la predizione della maschera sopra una certa soglia di confidenza 0.5 (0-1)
masks = (pred[0]['masks']>0.5).squeeze().detach().cpu().numpy()
# Estrapoliamo il nome della classe
pred_class = [COCO_INSTANCE_CATEGORY_NAMES[i] for i in list(pred[0]['labels'].numpy())]
# Estrapoliamo la bounding box
pred_boxes = [[(i[0], i[1]), (i[2], i[3])] for i in list(pred[0]['boxes'].detach().numpy())]
# Shrink delle strutture per matching della predizione
masks = masks[:pred_t+1]
pred_boxes = pred_boxes[:pred_t+1]
pred_class = pred_class[:pred_t+1]
return masks, pred_boxes, pred_class
def random_colour_masks(image):
"""
Questa funzione genera una maschera con un colore random
"""
# Colori codificati in RGB
colours = [[0, 255, 0],[0, 0, 255],[255, 0, 0],[0, 255, 255],[255, 255, 0],[255, 0, 255],[80, 70, 180],[250, 80, 190],[245, 145, 50],[70, 150, 250],[50, 190, 190]]
# Generazione dei canali RGB della maschera
r = np.zeros_like(image).astype(np.uint8)
g = np.zeros_like(image).astype(np.uint8)
b = np.zeros_like(image).astype(np.uint8)
# Setting del colore in base alla maschera
r[image == 1], g[image == 1], b[image == 1] = colours[np.random.randint(0,10)]
# Stack dei layers RGB
coloured_mask = np.stack([r, g, b], axis=2)
return coloured_mask
def instance_segmentation_api(img_path, threshold=0.5, rect_th=3, text_size=3, text_th=3):
"""
Questa funzione permette di chiamare il modello, avere le predizioni e
generare un overlay di visualizzazione delle predizioni
"""
# Ritorniamo le predizioni
masks, boxes, pred_cls = get_prediction(img_path, threshold)
# Lettura dell'immagine e generazione delle bounding boxes e maschere
img = cv2.imread(img_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
for i in range(len(masks)):
rgb_mask = random_colour_masks(masks[i])
img = cv2.addWeighted(img, 1, rgb_mask, 0.5, 0)
cv2.rectangle(img, boxes[i][0], boxes[i][1],color=(0, 255, 0), thickness=rect_th)
cv2.putText(img,pred_cls[i], boxes[i][0], cv2.FONT_HERSHEY_SIMPLEX, text_size, (0,255,0),thickness=text_th)
# Visualizzazione dell'immagine
plt.figure(figsize=(20,30))
plt.imshow(img)
plt.axis('off')
plt.show()
Download immagine¶
!wget https://images-na.ssl-images-amazon.com/images/I/A1ppzg2gLwL._SL1500_.jpg -O btles.jpg
img = cv2.imread('./btles.jpg', cv2.IMREAD_UNCHANGED)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
plt.figure(figsize=(20,30))
plt.imshow(img)
plt.axis('off')
plt.show()
Get predictions!¶
instance_segmentation_api('./btles.jpg', threshold=0.8, text_size=1, text_th=2, rect_th=2)
Hands on LAB 👩💻¶
Challenge 1: Teletrasporto spaziale 🌌¶
Siete pronti? In questo hands-on lab prenderemo una foto che ritrae una persona e cercheremo di segmentarla attraverso il nostro modello. Dopodichè scaricheremo una foto dello spazio e cercheremo di telegrasportare la persona direttamente nello spazio.
Scarica una foto se non possiedi una webcam¶
!wget https://www.agenpress.it/wp-content/uploads/2019/10/AP_19201004713022-1000x667.jpg -O photo.jpg
[ Opzionale ] Webcam¶
Se avete una webcam scattate una foto di voi stessi, altrimenti salta alla cella successiva.
#@title 👈 Cattura
from IPython.display import display, Javascript
from google.colab.output import eval_js
from base64 import b64decode
def take_photo(filename='photo.jpg', quality=0.8):
js = Javascript('''
async function takePhoto(quality) {
const div = document.createElement('div');
const capture = document.createElement('button');
capture.textContent = 'Cattura';
div.appendChild(capture);
const video = document.createElement('video');
video.style.display = 'block';
const stream = await navigator.mediaDevices.getUserMedia({video: true});
document.body.appendChild(div);
div.appendChild(video);
video.srcObject = stream;
await video.play();
// Resize the output to fit the video element.
google.colab.output.setIframeHeight(document.documentElement.scrollHeight, true);
// Wait for Capture to be clicked.
await new Promise((resolve) => capture.onclick = resolve);
const canvas = document.createElement('canvas');
canvas.width = video.videoWidth;
canvas.height = video.videoHeight;
canvas.getContext('2d').drawImage(video, 0, 0);
stream.getVideoTracks()[0].stop();
div.remove();
return canvas.toDataURL('image/jpeg', quality);
}
''')
display(js)
data = eval_js('takePhoto({})'.format(quality))
binary = b64decode(data.split(',')[1])
with open(filename, 'wb') as f:
f.write(binary)
return filename
from IPython.display import Image
try:
filename = take_photo()
print('Saved to {}'.format(filename))
# Show the image which was just taken.
display(Image(filename))
except Exception as err:
# Errors will be thrown if the user does not have a webcam or if they do not
# grant the page permission to access it.
print(str(err))
Visualizziamo l'immagine¶
import cv2
import matplotlib.pyplot as plt
# Prendiamo la foto appena scattata e invertiamo i canali da BGR a RGB
img = cv2.imread("./photo.jpg", cv2.IMREAD_UNCHANGED)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# Visualizziamo l'immagine
plt.figure(figsize=(10,10))
plt.imshow(img)
plt.axis('off')
plt.show()
Otteniamo la segmentazione¶
import torchvision
import numpy as np
# Instanziamo il modello da torchvision
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
# Settiamo il modello in "eval" mode in quanto il comportamento durante il training differisce da quello di evaluation
model.eval()
# Trasformiamo l'immagine in tensore
image_tensor = torchvision.transforms.functional.to_tensor(img)
# Passiamo l'immagine alla rete
output = model([image_tensor])
# Recuperiamo la maschera...cosa succede se ci sono più persone?
mask = output[0]['masks'].detach().numpy()[0,0,:]
# Teniamo solamente i pixel corrispondenti alla segmentazione
img[:,:,0] = img[:,:,0]*mask
img[:,:,1] = img[:,:,1]*mask
img[:,:,2] = img[:,:,2]*mask
# Visualizziamo l'immagine
plt.figure(figsize=(10,10))
plt.imshow(img)
plt.axis('off')
plt.show()
Background spaziale¶
!wget https://thebuzzpaper.com/wp-content/uploads/2019/11/space-signals-3246.jpg -O bground.jpg
# Prendiamo il background e invertiamo i canali da BGR a RGB
bground = cv2.imread("./bground.jpg", cv2.IMREAD_UNCHANGED)
bground = cv2.cvtColor(bground, cv2.COLOR_BGR2RGB)
# Resize dell'immagine all'immagine della persona
bground = cv2.resize(bground, (img.shape[1],img.shape[0]))
#----------------- Visualizziamo l'immagine di background ----------------------
# Visualizziamo l'immagine
plt.figure(figsize=(10,10))
plt.imshow(bground)
plt.axis('off')
plt.show()
#-------------------------------------------------------------------------------
Teletrasportiamoci!¶
#------------ Riusciresti a teletrasportare la persona nello spazio? -----------
not_mask = mask < 0.5
not_mask = not_mask.astype(np.float)
bground[:,:,0] = bground[:,:,0]*not_mask
bground[:,:,1] = bground[:,:,1]*not_mask
bground[:,:,2] = bground[:,:,2]*not_mask
risultato = bground + img
#-------------------------------------------------------------------------------
# Visualizziamo il risultato
plt.figure(figsize=(10,10))
plt.imshow(risultato)
plt.axis('off')
plt.show()